/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.test.iterative;

import org.apache.flink.api.common.functions.FlatJoinFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.aggregation.Aggregations;
import org.apache.flink.api.java.functions.FunctionAnnotation;
import org.apache.flink.api.java.operators.DeltaIteration;
import org.apache.flink.api.java.tuple.Tuple1;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.examples.java.graph.ConnectedComponents;
import org.apache.flink.test.testdata.ConnectedComponentsData;
import org.apache.flink.test.util.JavaProgramTestBase;
import org.apache.flink.util.Collector;

import java.io.BufferedReader;

/**
 * Regression test for a bug that prevented the solution set of a delta iteration from being used
 * on either side of a match/cogroup function.
 *
 * <p>The program computes connected components with the solution set placed on the first (left)
 * input of the join that updates component IDs.
 */
public class ConnectedComponentsWithSolutionSetFirstITCase extends JavaProgramTestBase {

	/** Seed for the random edge generator, so the input graph is reproducible. */
	private static final long SEED = 0xBADC0FFEEBEEFL;

	/** Number of vertices in the generated graph. */
	private static final int NUM_VERTICES = 1000;

	/** Number of edges in the generated graph. */
	private static final int NUM_EDGES = 10000;

	/** Upper bound on the number of delta-iteration supersteps. */
	private static final int MAX_ITERATIONS = 100;

	/** Path of the temporary file holding the vertex IDs. */
	protected String verticesPath;
	/** Path of the temporary file holding the space-separated edge pairs. */
	protected String edgesPath;
	/** Path where the program writes the computed (vertex, component) pairs. */
	protected String resultPath;

	/**
	 * Writes the vertex and edge input files to temporary locations and reserves a temporary
	 * path for the program output.
	 */
	@Override
	protected void preSubmit() throws Exception {
		verticesPath = createTempFile("vertices.txt", ConnectedComponentsData.getEnumeratingVertices(NUM_VERTICES));
		edgesPath = createTempFile("edges.txt", ConnectedComponentsData.getRandomOddEvenEdges(NUM_EDGES, NUM_VERTICES, SEED));
		resultPath = getTempFilePath("results");
	}

	/**
	 * Builds and runs the connected-components delta iteration. The solution set is the first
	 * input of the update join, which is the arrangement this test guards.
	 */
	@Override
	protected void testProgram() throws Exception {
		// set up execution environment
		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

		// read vertex and edge data
		DataSet<Tuple1<Long>> vertices = env.readCsvFile(verticesPath).types(Long.class);

		// edges are read as directed pairs and expanded to both directions by UndirectEdge
		DataSet<Tuple2<Long, Long>> edges = env.readCsvFile(edgesPath).fieldDelimiter(" ").types(Long.class, Long.class)
				.flatMap(new ConnectedComponents.UndirectEdge());

		// assign the initial components (equal to the vertex id)
		DataSet<Tuple2<Long, Long>> verticesWithInitialId = vertices.map(new ConnectedComponentsITCase.DuplicateValue<Long>());

		// open a delta iteration; solution set and initial workset are both keyed on field 0 (vertex id)
		DeltaIteration<Tuple2<Long, Long>, Tuple2<Long, Long>> iteration =
				verticesWithInitialId.iterateDelta(verticesWithInitialId, MAX_ITERATIONS, 0);

		// apply the step logic: join with the edges, select the minimum neighbor, update if the component of the candidate is smaller
		DataSet<Tuple2<Long, Long>> minNeighbor = iteration.getWorkset()
				.join(edges).where(0).equalTo(0).with(new ConnectedComponents.NeighborWithComponentIDJoin())
				.groupBy(0).aggregate(Aggregations.MIN, 1);

		// the solution set is deliberately the FIRST join input here -- this is what the test exercises
		DataSet<Tuple2<Long, Long>> updatedIds = iteration.getSolutionSet()
				.join(minNeighbor).where(0).equalTo(0).with(new UpdateComponentIdMatchMirrored());

		// close the delta iteration (delta and new workset are identical)
		DataSet<Tuple2<Long, Long>> result = iteration.closeWith(updatedIds, updatedIds);

		result.writeAsCsv(resultPath, "\n", " ");

		// execute program
		env.execute("Connected Components Example");
	}

	/**
	 * Validates every result file against the expected odd/even component assignment.
	 * Each reader is closed after it is checked, including when the check fails.
	 */
	@Override
	protected void postSubmit() throws Exception {
		for (BufferedReader reader : getResultReader(resultPath)) {
			// close the reader even if the check throws, so result-file handles are not leaked
			try (BufferedReader resultReader = reader) {
				ConnectedComponentsData.checkOddEvenResult(resultReader);
			}
		}
	}

	// --------------------------------------------------------------------------------------------
	//  Classes and methods for the test program
	// --------------------------------------------------------------------------------------------

	/**
	 * Update join with the solution-set entry as the first argument and the candidate as the
	 * second. It emits the candidate unchanged when its component ID is smaller than the current
	 * one; otherwise it emits nothing.
	 */
	@FunctionAnnotation.ForwardedFieldsSecond("*")
	private static final class UpdateComponentIdMatchMirrored
			implements FlatJoinFunction<Tuple2<Long, Long>, Tuple2<Long, Long>, Tuple2<Long, Long>> {
		private static final long serialVersionUID = 1L;

		/**
		 * @param current   the solution-set record (vertex id, current component id)
		 * @param candidate the candidate record (vertex id, minimum neighbor component id)
		 * @param out       receives {@code candidate} if it improves the component id
		 */
		@Override
		public void join(
				Tuple2<Long, Long> current,
				Tuple2<Long, Long> candidate,
				Collector<Tuple2<Long, Long>> out) throws Exception {

			// both fields are Long; '<' unboxes them, so this is a numeric comparison
			if (candidate.f1 < current.f1) {
				out.collect(candidate);
			}
		}
	}
}
